#!/usr/bin/python
# compute-best-mix.py
#
# Copyright (c) [2014-], Josef Robert Novak
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions
#  are met:
#
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above
#    copyright notice, this list of conditions and the following
#    disclaimer in the documentation and/or other materials provided
#    with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
# INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
# OF THE POSSIBILITY OF SUCH DAMAGE.
#
# This is just a vanilla python port of the SRILM script
#   * compute-best-mix
#####################################
import re, math

def LoadPPLFile (pplfile) :
    """
      Load up the salient info from a -debug 2 PPL file
      generated by the SRILM ngram tool.

      Each scored-word line of the form
          p( WORD | CONTEXT ) = [Ngram] PROB [ LOG10PROB ]
      yields one record:  [WORD, N, LOG10PROB].

      pplfile: path to the output of 'ngram -debug 2 -ppl ...'.
      Returns: list of [str token, int ngram_order, float log10 prob].
    """
    ppl_info = []
    # 'with' guarantees the file is closed; the original left it open.
    with open (pplfile, "r") as ppl_fp :
        for line in ppl_fp :
            line = line.strip ()
            if not line.startswith ("p(") :
                continue
            tok, prob = line.split ("=")
            # Strip the square brackets, then split on whitespace:
            #   " [2gram] 0.085 [ -0.070 ]" -> ["2gram", "0.085", "-0.070"]
            probs = re.split (r"\s+", re.sub (r"[\[\]]", "", prob).strip ())
            # Reduce "p( WORD | CONTEXT )" to just the predicted WORD.
            tok   = re.sub (r"^p\( ", "", tok)
            tok   = re.sub (r" \|.*$", "", tok)
            ppl_info.append ([tok, int (re.sub (r"gram", "", probs[0])),
                              float (probs[2])])
    # NOTE: the original printed len(ppl_info) and ppl_info[2] here —
    # leftover debug output that also crashed (IndexError) on any file
    # with fewer than three scored words.  Removed.
    return ppl_info

class MixtureComputer () :
    """
      Python port of the SRILM gawk tool 'compute-best-mix'.
      Should produce the same result as:
        $ compute-best-mix ppl.file1 ppl.file2 ... ppl.fileN
    """
    def __init__ (self, ppl_infos, lambdas=None, precision=0.001, verbose=False) :
        """
          ppl_infos: one list per component LM of [token, ngram-order,
                     log10 prob] records (as produced by LoadPPLFile).
                     All components must cover the same word sequence.
          lambdas:   optional initial mixture weights, one per component.
                     Defaults to uniform 1/N.  (None default replaces the
                     original mutable-default-argument [] pitfall; passing
                     an explicit [] still means "uniform".)
          precision: per-component convergence threshold on the change in
                     a prior between EM iterations.
          verbose:   stored for callers; previously accepted and dropped.
        """
        self.M_LN10    = 2.30258509299404568402 # ln(10); kept for SRILM parity
        self.logINF    = -320                   # log-prob floor; kept for SRILM parity
        self.precision = precision
        self.ppl_infos = ppl_infos
        self.verbose   = verbose
        self.nc        = len (self.ppl_infos)   # number of component models
        self.lambdas, self.priors  = self._init_lambdas (lambdas or [])
        self.max_iter  = 100
        self.post_totals = []
        self.log10priors = []

    def _init_lambdas (self, lambdas) :
        """
          Build the initial (lambdas, priors) pair.  An empty list means
          uniform initialization.  Priors are normalized by the weight
          total so user-supplied weights need not sum to one: the original
          computed lambda_sum but never applied it, which silently produced
          an improper prior distribution for non-normalized input.
        """
        if not lambdas :
            lambdas = [1.0 / self.nc] * self.nc
        lambda_sum = float (sum (lambdas))
        priors_    = [l / lambda_sum for l in lambdas]
        return lambdas, priors_

    def _sum_word (self, i) :
        """
          For test-set word i, compute each component's prior-scaled log10
          posterior and the log10 of their sum over all components.
          Returns (log_sum, [per-component log10 posteriors]).
        """
        log_posts_ = [self.log10priors[j] + self.ppl_infos[j][i][2]
                      for j in range (self.nc)]
        # Fold log10(10^a + 10^b) pairwise across the components.
        log_sum_ = log_posts_[0]
        for log_post_ in log_posts_[1:] :
            log_sum_ = math.log (
                (math.pow (10, log_sum_) + math.pow (10, log_post_)),
                10)
        return log_sum_, log_posts_

    def OptimizeLambdas (self) :
        """
          Iterative EM-like re-estimation of the mixture weights:

          1.  For each word, compute each component's prior-scaled log
              posterior and the log-sum over components (_sum_word).
          2.  Accumulate per-component posterior totals
              10^(log_post - log_sum).
          3.  Re-estimate each prior as its posterior total divided by the
              number of scored words.
          4.  Stop when every prior moved by less than 'precision', or
              after max_iter iterations, whichever comes first.

          One status line is printed per iteration:
              <iteration> <priors...> <test-set perplexity>
          The final weights are left in self.priors.
        """
        have_converged = False
        iteration      = 0
        nwords         = len (self.ppl_infos[0])  # scored words in the test set
        while not have_converged :
            iteration += 1
            log_like    = 0.0
            post_totals = [0.0 for _ in self.ppl_infos]
            self.log10priors = [math.log (self.priors[j], 10)
                                for j in range (self.nc)]

            for i in range (nwords) :
                # Compute the sum for this word, across all components.
                log_sum, log_posts = self._sum_word (i)
                log_like += log_sum
                for j in range (self.nc) :
                    post_totals[j] += math.pow (10, log_posts[j] - log_sum)

            # print()-style call keeps this valid under Python 2 and 3;
            # the %s formatting reproduces the original py2 output exactly.
            print ("%s %s %s" % (
                iteration,
                " ".join ([str (x) for x in self.priors]),
                math.pow (10, -log_like / nwords)))

            have_converged = True
            for j in range (self.nc) :
                last_prior     = self.priors[j]
                self.priors[j] = post_totals[j] / nwords
                if abs (last_prior - self.priors[j]) > self.precision :
                    have_converged = False
            # '>=' fixes the original off-by-one ('>'), which allowed
            # max_iter + 1 iterations before forcing termination.
            if iteration >= self.max_iter :
                have_converged = True
        return


if __name__=="__main__" :
    import sys, argparse

    # CLI: a comma-separated list of 'ngram -debug 2 -ppl' output files,
    # one per component model of the mixture.
    example = "USAGE: {0} --ppl ppl.1.txt,ppl.2.txt,ppl.3.txt".format (sys.argv[0])
    parser  = argparse.ArgumentParser (description = example)
    parser.add_argument ("--ppl",     "-p", help="List of ppl files from 'ngram'.", required=True)
    parser.add_argument ("--verbose", "-v", help="Verbose mode.", default=False, action="store_true")
    args = parser.parse_args ()

    pplfiles = args.ppl.split (",")
    # One list of [token, order, log10prob] records per component model.
    pplinfos = [LoadPPLFile (f) for f in pplfiles]

    # Forward --verbose (previously parsed but never used).
    mixer = MixtureComputer (pplinfos, verbose=args.verbose)
    mixer.OptimizeLambdas ()
    # print() call form is valid (and prints identically) under
    # both Python 2 and Python 3; the original py2 statement was not.
    print (mixer.priors)